|
""" |
|
Class to handle llm wildcard routing and regex pattern matching |
|
""" |
|
|
|
import copy |
|
import re |
|
from re import Match |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
from litellm import get_llm_provider |
|
from litellm._logging import verbose_router_logger |
|
|
|
|
|
class PatternUtils: |
|
@staticmethod |
|
def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]: |
|
""" |
|
Calculate pattern specificity based on length and complexity. |
|
|
|
Args: |
|
pattern: Regex pattern to analyze |
|
|
|
Returns: |
|
Tuple of (length, complexity) for sorting |
|
""" |
|
complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"] |
|
ret_val = ( |
|
len(pattern), |
|
sum( |
|
pattern.count(char) for char in complexity_chars |
|
), |
|
) |
|
return ret_val |
|
|
|
@staticmethod |
|
def sorted_patterns( |
|
patterns: Dict[str, List[Dict]] |
|
) -> List[Tuple[str, List[Dict]]]: |
|
""" |
|
Cached property for patterns sorted by specificity. |
|
|
|
Returns: |
|
Sorted list of pattern-deployment tuples |
|
""" |
|
return sorted( |
|
patterns.items(), |
|
key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]), |
|
reverse=True, |
|
) |
|
|
|
|
|
class PatternMatchRouter:
    """
    Class to handle llm wildcard routing and regex pattern matching

    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing

    This class will store a mapping for regex pattern: List[Deployments]
    """

    def __init__(self):
        # Maps regex pattern string -> list of deployment dicts registered for it.
        self.patterns: Dict[str, List] = {}

    def add_pattern(self, pattern: str, llm_deployment: Dict):
        """
        Add a regex pattern and the corresponding llm deployments to the patterns

        Args:
            pattern: str - wildcard pattern, e.g. "openai/*"
            llm_deployment: Dict - deployment to register under this pattern
        """
        regex = self._pattern_to_regex(pattern)
        # Multiple deployments may share the same wildcard pattern.
        self.patterns.setdefault(regex, []).append(llm_deployment)

    def _pattern_to_regex(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to a regex pattern

        example:
        pattern: openai/*
        regex: openai/(.*)

        pattern: openai/fo::*::static::*
        regex: openai/fo::(.*)::static::(.*)

        Args:
            pattern: str

        Returns:
            str: regex pattern

        Note: each "*" becomes a CAPTURING group "(.*)" so the matched
        segments can later be substituted into the deployment's model name
        by set_deployment_model_name.
        """
        # Escape regex metacharacters first so literal ".", "-", etc. in the
        # user's pattern are matched verbatim; then re-enable the wildcard.
        return re.escape(pattern).replace(r"\*", "(.*)")

    def _return_pattern_matched_deployments(
        self, matched_pattern: Match, deployments: List[Dict]
    ) -> List[Dict]:
        """
        Return copies of the deployments with their litellm model names
        rewritten using the wildcard segments captured in matched_pattern.

        Deployments are deep-copied so the stored originals are never mutated.
        """
        new_deployments = []
        for deployment in deployments:
            new_deployment = copy.deepcopy(deployment)
            new_deployment["litellm_params"]["model"] = (
                PatternMatchRouter.set_deployment_model_name(
                    matched_pattern=matched_pattern,
                    litellm_deployment_litellm_model=deployment["litellm_params"][
                        "model"
                    ],
                )
            )
            new_deployments.append(new_deployment)

        return new_deployments

    def route(
        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
    ) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

        loop through all the patterns (most specific first) and find the matching pattern
        if a pattern is found, return the corresponding llm deployments
        if no pattern is found, return None

        Args:
            request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
            if request is None:
                return None

            # Most specific pattern wins, so sort before scanning.
            sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
            regex_filtered_model_names = (
                [self._pattern_to_regex(m) for m in filtered_model_names]
                if filtered_model_names is not None
                else []
            )
            for pattern, llm_deployments in sorted_patterns:
                if (
                    filtered_model_names is not None
                    and pattern not in regex_filtered_model_names
                ):
                    continue
                # re.match anchors at the start of the request string only.
                pattern_match = re.match(pattern, request)
                if pattern_match:
                    return self._return_pattern_matched_deployments(
                        matched_pattern=pattern_match, deployments=llm_deployments
                    )
        except Exception as e:
            # Best-effort routing: log and fall through to "no match".
            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")

        return None  # no matching pattern found

    @staticmethod
    def set_deployment_model_name(
        matched_pattern: Match,
        litellm_deployment_litellm_model: str,
    ) -> str:
        """
        Set the model name for the matched pattern llm deployment

        E.g.:

        Case 1:
        model_name: llmengine/* (can be any regex pattern or wildcard pattern)
        litellm_params:
            model: openai/*

        if model_name = "llmengine/foo" -> model = "openai/foo"

        Case 2:
        model_name: llmengine/fo::*::static::*
        litellm_params:
            model: openai/fo::*::static::*

        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"

        Case 3:
        model_name: *meta.llama3*
        litellm_params:
            model: bedrock/meta.llama3*

        if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
        """
        # No wildcards to fill in -> the deployment model name is literal.
        if "*" not in litellm_deployment_litellm_model:
            return litellm_deployment_litellm_model

        # Count how many wildcard slots the deployment model name exposes.
        wildcard_count = litellm_deployment_litellm_model.count("*")

        # Segments captured by the "(.*)" groups of the matched route pattern.
        dynamic_segments = matched_pattern.groups()

        if len(dynamic_segments) > wildcard_count:
            # More captured segments than slots: substitution is ambiguous,
            # so fall back to the raw requested model name.
            return matched_pattern.string

        # Replace wildcards with the captured segments, in order.
        for segment in dynamic_segments:
            litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
                "*", segment, 1
            )

        return litellm_deployment_litellm_model

    def get_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> Optional[List[Dict]]:
        """
        Check if a pattern exists for the given model and custom llm provider

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            Optional[List[Dict]]: matching llm deployments, or None if no
            pattern matches (fixed: previous docs wrongly claimed bool)
        """
        if custom_llm_provider is None:
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises for unmapped models; best-effort only.
                pass
        if custom_llm_provider is None:
            # Provider could not be resolved; don't try the bogus
            # "None/<model>" route (which could never meaningfully match).
            return self.route(model)
        return self.route(model) or self.route(f"{custom_llm_provider}/{model}")

    def get_deployments_by_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> List[Dict]:
        """
        Get the deployments by pattern

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            List[Dict]: llm deployments matching the pattern, [] if none match
        """
        return self.get_pattern(model, custom_llm_provider) or []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|